define a simple spline¶

In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm


def bspline_basis_matrix():
    """
    Build the 4x4 basis matrix of a uniform cubic B-spline.

    The matrix is meant to be used as ``[u^3, u^2, u, 1] @ M @ G`` where ``G``
    stacks four consecutive control points; each row below holds the
    coefficients of one power of ``u`` (highest power first), before the
    common 1/6 scaling is applied.

    Returns:
        Tensor of shape (4, 4) holding the scaled coefficients.
    """
    unscaled = torch.tensor([
        [-1,  3, -3, 1],   # u^3 coefficients
        [ 3, -6,  3, 0],   # u^2 coefficients
        [-3,  0,  3, 0],   # u^1 coefficients
        [ 1,  4,  1, 0],   # constant term
    ])
    scale = 1.0 / 6.0
    return scale * unscaled


class ClampedBSplineTrajectoryOptimizer:
    """
    Path parameterised as a clamped uniform cubic B-spline with learnable
    interior control points.

    The full control polygon is ``[start]*3 + internal + [goal]*3``. With the
    uniform cubic basis, a segment whose first three control points coincide
    starts exactly at that point, and a segment whose last three control
    points coincide ends exactly at that point; this is what pins the curve
    to `start` and `goal`.
    """

    def __init__(self, start, goal, num_internal_ctrl_pts=6, dim=2, lr=0.05, device='cpu'):
        """
        Uses clamped B-spline with repeated endpoints to ensure path starts at `start` and ends at `goal`.

        Args:
            start, goal: Tensors of shape (dim,)
            num_internal_ctrl_pts: Number of learnable internal control points
            dim: Stored on the instance as `self.dim`; not otherwise used here
            lr: Learning rate handed to the Adam optimizer
            device: Device on which all tensors of this object are placed
        """
        self.device = device
        self.dim = dim
        self.start = start.to(device)
        self.goal = goal.to(device)

        # Internal control points (learnable): evenly spaced on the straight
        # line from start to goal, endpoints themselves excluded.
        internal = torch.linspace(0, 1, num_internal_ctrl_pts + 2, device=device).unsqueeze(1)
        internal = internal[1:-1]  # exclude endpoints
        internal_ctrl = internal * (self.goal - self.start) + self.start

        # Small Gaussian perturbation so the initial path is not perfectly straight.
        eps = 0.05 * torch.randn_like(internal_ctrl)
        internal_ctrl += eps

        self.internal_ctrl_pts = nn.Parameter(internal_ctrl)
        self.lr = lr
        # Only the internal control points are optimised; start/goal stay fixed.
        self.optimizer = optim.Adam([self.internal_ctrl_pts], lr=lr)
        self.basis = bspline_basis_matrix().to(device)

    def get_full_ctrl_pts(self):
        """
        Returns the complete control polygon of shape
        (num_internal_ctrl_pts + 6, dim): start x3, internal points, goal x3.
        """
        # Repeat start and goal 3 times for cubic clamping
        return torch.cat([
            self.start.expand(3, -1),
            self.internal_ctrl_pts,
            self.goal.expand(3, -1)
        ], dim=0)

    def evaluate_spline(self, resolution=100, stochastic=False):
        """
        Returns interpolated points along the clamped B-spline.
        If `stochastic` is True, samples random u values instead of fixed linspace.

        Args:
            resolution: Approximate total number of samples. Each segment gets
                `resolution // segments` samples, so the returned count is
                `segments * (resolution // segments)`.
            stochastic: If True, draw the per-segment parameters uniformly at
                random in [0, 1) (unsorted); otherwise use an even grid on
                [0, 1], in which case the junction between two segments is
                sampled by both of them.

        Returns:
            Tensor of shape (total_samples, dim).
        """
        ctrl_pts = self.get_full_ctrl_pts()
        # A cubic segment needs 4 consecutive control points, so K control
        # points define K - 3 segments.
        segments = ctrl_pts.shape[0] - 3
        samples_per_segment = resolution // segments
        points = []

        # Iterate over ALL segments. Stopping one short would drop the final
        # segment [p_n, goal, goal, goal], which is the only one that reaches
        # `goal`, leaving the path to end at (p_n + 5 * goal) / 6 instead.
        for i in range(segments):
            G = ctrl_pts[i:i + 4]  # 4 control points

            if stochastic:
                u_vals = torch.rand(samples_per_segment, device=self.device)
            else:
                u_vals = torch.linspace(0, 1, samples_per_segment, device=self.device)

            U = torch.stack([u_vals**3, u_vals**2, u_vals, torch.ones_like(u_vals)], dim=1)  # (N, 4)
            segment_points = (U @ self.basis) @ G  # (N, dim)
            points.append(segment_points)

        return torch.cat(points, dim=0)  # shape (total_samples, dim)





def sdf_cost(sdf_values, alpha=35.0, beta=0.1, threshold=0.25):
    """
    Converts SDF values into cost values.

    Two exponential profiles are used and selected element-wise by comparing
    the signed distance with `threshold`:

        sdf <  threshold :  exp(-alpha * sdf)
        sdf >= threshold :  beta * exp(-sdf / beta)

    Note that the switch happens at `threshold`, not at the obstacle surface
    (sdf == 0): with the default 0.25 the steep profile also covers a margin
    outside the obstacle. The two profiles are not matched at the switching
    point, so the cost is generally discontinuous there.

    sdf_values: Tensor of signed distances at each query point
    alpha: Steepness of the exponential used below `threshold`
    beta: Scale and decay length of the exponential used at/above `threshold`
    threshold: Signed distance at which the cost switches between the two
        profiles (default 0.25 reproduces the previous hard-coded behaviour)

    Returns:
        cost: Tensor of same shape as `sdf_values`
    """
    # torch.where evaluates both branches for every element and then selects.
    cost = torch.where(
        sdf_values < threshold,
        torch.exp(-alpha * sdf_values),         # near / inside obstacle: steep cost
        beta * torch.exp(-sdf_values / beta)    # far from obstacle: small decaying cost
    )
    return cost

def spacing_regularizer(ctrl_pts, strength=1.0):
    """
    Penalty on uneven spacing of a control polygon.

    Computes the Euclidean length of every edge between consecutive control
    points and returns `strength` times the mean squared deviation of those
    lengths from their own mean (zero when all edges are equally long).

    ctrl_pts: Tensor of shape (N, D), rows ordered along the path
    strength: Scalar weight applied to the penalty
    """
    edge_lengths = torch.norm(ctrl_pts[1:] - ctrl_pts[:-1], dim=1)
    deviation = edge_lengths - edge_lengths.mean()
    return strength * (deviation ** 2).mean()

def repulsion_cost(ctrl_pts, min_dist=0.2, strength=1.0):
    """
    Log-barrier penalty that pushes control points apart.

    Every unordered pair of control points closer than `min_dist` contributes
    ``-log(d / min_dist + 1e-6)`` where ``d`` is their Euclidean distance;
    pairs at or beyond `min_dist` contribute nothing.

    ctrl_pts: Tensor of shape (N, D)
    min_dist: Distance below which a pair is penalised
    strength: Scalar weight applied to the summed penalty
    """
    num_pts = ctrl_pts.shape[0]

    # Row 0 / row 1 hold the first / second index of each pair (first < second).
    pair_idx = torch.triu_indices(num_pts, num_pts, offset=1)
    first, second = pair_idx[0], pair_idx[1]

    pair_dists = torch.norm(ctrl_pts[first] - ctrl_pts[second], dim=1)

    # Keep only the pairs that violate the minimum separation.
    too_close = pair_dists < min_dist
    barrier = -torch.log(pair_dists[too_close] / min_dist + 1e-6)

    return strength * barrier.sum()
In [30]:
import numpy as np
import plotly.graph_objects as go

def plot_sdf_grid_plotly_2d(env, grid_resolution=100, name="2D SDF Visualization"):
    """
    Plots the SDF-derived cost field of the environment using plotly in 2D.

    The signed distances returned by `env.compute_sdf` are passed through
    `sdf_cost` before plotting, so the heatmap shows cost, not raw distance.

    Args:
        env: An instance of EnvBase or its subclass. This function reads
            `env.limits_np` (indexed as [0][axis] for the lower and [1][axis]
            for the upper bound), `env.tensor_args` (keyword arguments passed
            to `torch.linspace`) and calls `env.compute_sdf`.
        grid_resolution: Number of points per axis.
        name: Figure title.

    Returns:
        The plotly Figure (not shown; the caller decides when to display it).
    """
    # Define grid
    x = torch.linspace(env.limits_np[0][0], env.limits_np[1][0], grid_resolution, **env.tensor_args)
    y = torch.linspace(env.limits_np[0][1], env.limits_np[1][1], grid_resolution, **env.tensor_args)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    points = torch.stack([X, Y], dim=-1).view(-1, 1, 2)

    # Compute SDF and convert it to a cost field
    sdf = env.compute_sdf(points).view(grid_resolution, grid_resolution)
    cost = sdf_cost(sdf)

    # With indexing='ij', cost[i, j] is the value at (x[i], y[j]). Plotly reads
    # z[row][col] as the value at (x[col], y[row]), so the array has to be
    # transposed; otherwise the field is mirrored about the diagonal relative
    # to any trace drawn in true (x, y) coordinates.
    cost_np = cost.detach().cpu().numpy().T

    # Plot with plotly
    fig = go.Figure(
        data=go.Contour(
            z=cost_np,
            x=x.cpu().numpy(),
            y=y.cpu().numpy(),
            colorscale='Viridis',
            colorbar=None,
            line_width=0,
            contours=dict(
                coloring ='heatmap',
                showlabels = False, # show labels on contours
                labelfont = dict( # label font properties
                    size = 12,
                    color = 'white',
                )
            )
        )
    )
    fig.update_layout(
        xaxis_title='X',
        yaxis_title='Y',
        title=name,
        # Lock the aspect ratio so distances look the same on both axes.
        yaxis_scaleanchor="x",
        yaxis_scaleratio=1
    )
    return fig

Set up planning in an SDF environment via a spline¶

In [31]:
def plan_env(env2d, name="env"):
    """
    Optimise a clamped B-spline path through a 2D environment and plot it.

    The cost field of `env2d` is drawn first, then a spline from (-1, -1) to
    (1, 1) with 16 internal control points is optimised for 200 Adam steps on
    a weighted sum of:
      * obstacle cost: `sdf_cost` of the SDF at randomly sampled spline points,
      * length cost: polyline length of evenly sampled spline points,
      * spacing cost: `spacing_regularizer` on the full control polygon.
    The initial path, final path and final control points are overlaid on the
    figure, which is then shown.

    Args:
        env2d: Environment providing `compute_sdf` and `tensor_args['device']`
            (plus whatever `plot_sdf_grid_plotly_2d` needs).
        name: Title used for the figure.
    """
    # Visualize SDF grid in 2D
    fig = plot_sdf_grid_plotly_2d(env2d, grid_resolution=100, name=name)

    spline = ClampedBSplineTrajectoryOptimizer(
        start=torch.tensor([-1.0, -1.0]),
        goal=torch.tensor([1.0, 1.0]),
        num_internal_ctrl_pts=16,
        dim=2,
        lr=0.002,
        device=env2d.tensor_args['device'],
    )

    # plot initial path
    initial_spline_points = spline.evaluate_spline().detach().cpu()
    fig.add_trace(go.Scatter(x=initial_spline_points[:, 0], y=initial_spline_points[:, 1], mode='lines', name='Initial Path', line=dict(color='blue')))

    with tqdm.trange(200) as t:
        for _ in t:
            spline.optimizer.zero_grad()

            # Obstacle term uses random curve parameters, so different points
            # of the curve are probed at every step.
            spline_points = spline.evaluate_spline(resolution=100, stochastic=True)
            sdfs = env2d.compute_sdf(spline_points)
            obs_cost = sdf_cost(sdfs).sum()

            # Length term needs ordered samples, hence the deterministic grid.
            even_spline_points = spline.evaluate_spline(resolution=100, stochastic=False)
            len_cost = torch.norm(even_spline_points[1:] - even_spline_points[:-1], dim=1).sum()

            ctrl_pts = spline.get_full_ctrl_pts()
            spacing_cost = spacing_regularizer(ctrl_pts, strength=30)

            # All three terms are already scalars, so no further reduction is needed.
            cost = 5 * obs_cost + 2 * len_cost + spacing_cost

            cost.backward()
            spline.optimizer.step()

            t.set_postfix(obs_cost=obs_cost.item(), len_cost=len_cost.item(), spacing_cost=spacing_cost.item())

    # plot final path
    final_spline_points = spline.evaluate_spline().detach().cpu()
    fig.add_trace(go.Scatter(x=final_spline_points[:, 0], y=final_spline_points[:, 1], mode='lines', name='Final Path', line=dict(color='red')))
    ctrl_np = spline.get_full_ctrl_pts().detach().cpu().numpy()
    # The trace is drawn with mode='markers', so its colour must be set through
    # `marker`; a `line` colour has no effect on marker-only traces.
    fig.add_trace(go.Scatter(x=ctrl_np[:, 0], y=ctrl_np[:, 1], mode='markers', name='Control Points', marker=dict(color='rgba(255, 125, 0, 0.5)')))

    fig.show()
In [ ]:
 

test each environment¶

In [32]:
from torch_robotics.torch_utils.torch_utils import DEFAULT_TENSOR_ARGS

# 2D environments
from torch_robotics.environments.env_circle_2d import EnvCircle2D
from torch_robotics.environments.env_dense_2d import EnvDense2D
from torch_robotics.environments.env_dense_2d_extra_objects import EnvDense2DExtraObjects
from torch_robotics.environments.env_grid_circles_2d import EnvGridCircles2D
from torch_robotics.environments.env_narrow_passage_dense_2d import EnvNarrowPassageDense2D
from torch_robotics.environments.env_narrow_passage_dense_2d_extra_objects import EnvNarrowPassageDense2DExtraObjects
from torch_robotics.environments.env_planar2link import EnvPlanar2Link
from torch_robotics.environments.env_simple_2d import EnvSimple2D
from torch_robotics.environments.env_simple_2d_extra_objects import EnvSimple2DExtraObjects
from torch_robotics.environments.env_square_2d import EnvSquare2D

# 3D environments
from torch_robotics.environments.env_maze_boxes_3d import EnvMazeBoxes3D
from torch_robotics.environments.env_spheres_3d import EnvSpheres3D
from torch_robotics.environments.env_spheres_3d_extra_objects import EnvSpheres3DExtraObjects
from torch_robotics.environments.env_table_shelf import EnvTableShelf

# Environments to exercise, one entry per environment:
# (display name, environment class, True if the environment is 3D)
envs = [
    # 2D
    ("EnvCircle2D", EnvCircle2D, False),
    ("EnvDense2D", EnvDense2D, False),
    ("EnvDense2DExtraObjects", EnvDense2DExtraObjects, False),
    ("EnvGridCircles2D", EnvGridCircles2D, False),
    ("EnvNarrowPassageDense2D", EnvNarrowPassageDense2D, False),
    ("EnvNarrowPassageDense2DExtraObjects", EnvNarrowPassageDense2DExtraObjects, False),
    ("EnvPlanar2Link", EnvPlanar2Link, False),
    ("EnvSimple2D", EnvSimple2D, False),
    ("EnvSimple2DExtraObjects", EnvSimple2DExtraObjects, False),
    ("EnvSquare2D", EnvSquare2D, False),
    # 3D
    ("EnvMazeBoxes3D", EnvMazeBoxes3D, True),
    ("EnvSpheres3D", EnvSpheres3D, True),
    ("EnvSpheres3DExtraObjects", EnvSpheres3DExtraObjects, True),
    ("EnvTableShelf", EnvTableShelf, True),
]
i = 0
for name, cls, is3d in envs:
    print(f"Plotting {name} ({'3D' if is3d else '2D'}) ...")
    try:
        # Every environment is constructed (including its SDF precomputation),
        # but only the 2D ones are planned and plotted; 3D ones are skipped.
        env = cls(precompute_sdf_obj_fixed=True, sdf_cell_size=0.01, tensor_args=DEFAULT_TENSOR_ARGS)
        if not is3d:
            plan_env(env, name=name)
    except Exception as e:
        # Best effort: report the failure and continue with the next environment.
        print(f"Failed to plot {name}: {e}")
Plotting EnvCircle2D (2D) ...
Precomputing the SDF grid and gradients took: 0.044 sec
  0%|          | 0/200 [00:00<?, ?it/s, len_cost=2.95, obs_cost=4e+4, spacing_cost=0.235]
100%|██████████| 200/200 [00:03<00:00, 64.36it/s, len_cost=2.85, obs_cost=587, spacing_cost=0.248]    
Plotting EnvDense2D (2D) ...
Precomputing the SDF grid and gradients took: 0.036 sec
100%|██████████| 200/200 [00:02<00:00, 69.87it/s, len_cost=3.23, obs_cost=17.8, spacing_cost=0.519]
Plotting EnvDense2DExtraObjects (2D) ...
Precomputing the SDF grid and gradients took: 0.033 sec
100%|██████████| 200/200 [00:03<00:00, 55.90it/s, len_cost=3.13, obs_cost=27.8, spacing_cost=0.577]
Plotting EnvGridCircles2D (2D) ...
Precomputing the SDF grid and gradients took: 0.037 sec
100%|██████████| 200/200 [00:03<00:00, 60.96it/s, len_cost=3.76, obs_cost=25.7, spacing_cost=0.499]
Plotting EnvNarrowPassageDense2D (2D) ...
Precomputing the SDF grid and gradients took: 0.033 sec
100%|██████████| 200/200 [00:03<00:00, 59.21it/s, len_cost=3.26, obs_cost=2.56, spacing_cost=0.486]
Plotting EnvNarrowPassageDense2DExtraObjects (2D) ...
Precomputing the SDF grid and gradients took: 0.030 sec
100%|██████████| 200/200 [00:03<00:00, 51.55it/s, len_cost=3.15, obs_cost=5.62, spacing_cost=0.421]
Plotting EnvPlanar2Link (2D) ...
Precomputing the SDF grid and gradients took: 0.017 sec
100%|██████████| 200/200 [00:03<00:00, 57.98it/s, len_cost=3.03, obs_cost=633, spacing_cost=0.474]    
Plotting EnvSimple2D (2D) ...
Precomputing the SDF grid and gradients took: 0.024 sec
100%|██████████| 200/200 [00:03<00:00, 55.31it/s, len_cost=2.99, obs_cost=2.44, spacing_cost=0.421]
Plotting EnvSimple2DExtraObjects (2D) ...
Precomputing the SDF grid and gradients took: 0.020 sec
100%|██████████| 200/200 [00:03<00:00, 55.90it/s, len_cost=3.49, obs_cost=6.48, spacing_cost=0.602]
Plotting EnvSquare2D (2D) ...
Precomputing the SDF grid and gradients took: 0.030 sec
100%|██████████| 200/200 [00:03<00:00, 59.37it/s, len_cost=2.92, obs_cost=4.19e+5, spacing_cost=0.349]
Plotting EnvMazeBoxes3D (3D) ...
Precomputing the SDF grid and gradients took: 0.657 sec
Plotting EnvSpheres3D (3D) ...
Precomputing the SDF grid and gradients took: 0.117 sec
Plotting EnvSpheres3DExtraObjects (3D) ...
Precomputing the SDF grid and gradients took: 0.118 sec
Plotting EnvTableShelf (3D) ...
Precomputing the SDF grid and gradients took: 0.524 sec
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]: